import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import os
import os.path
import pandas as pd # type: ignore


def _read_mid_line_csv(path):
    """Read a COMSOL mid-line export and return it sorted by its first column (x).

    The first 7 rows of the file are skipped as header lines.
    """
    raw_df = pd.read_csv(path, skiprows=7, index_col=[0])
    raw_df.sort_index(axis=0, ascending=True, inplace=True)
    return raw_df


def _column(raw_df, quantity, freq):
    """Return a writable float copy of column '<quantity> @ freq=<freq>' of raw_df.

    A copy is taken because callers modify the result in place: with pandas
    Copy-on-Write, np.asarray on a column can return a read-only view.
    """
    return np.array(raw_df.loc[:, quantity + ' @ freq=' + str(freq)], dtype=float)


def _unwrap_phase_lag(phase):
    """Return a copy of phase unwrapped along x, removing upward 2*pi jumps only.

    Each sample is shifted down by 2*pi while that brings it closer to its
    (already unwrapped) predecessor, i.e. while the step from the previous
    sample exceeds pi. Downward jumps are left untouched (unlike np.unwrap).
    """
    unwrapped = np.array(phase, dtype=float)
    for i in range(1, len(unwrapped)):
        while np.abs(unwrapped[i] - unwrapped[i-1]) > np.abs(unwrapped[i] - 2*np.pi - unwrapped[i-1]):
            unwrapped[i] = unwrapped[i] - 2*np.pi
    return unwrapped


def _central_wavenumber(phase, x):
    """Return -d(phase)/dx by central differences, evaluated at the interior points x[1:-1]."""
    return -(phase[2:] - phase[:-2]) / (x[2:] - x[:-2])


def _window_mask(x, center, half_width):
    """Return a boolean mask of samples with center - half_width < x <= center + half_width."""
    x = np.asarray(x)
    return (x > center - half_width) & (x <= center + half_width)


def _masked_mean(values, mask):
    """Return the mean of values[mask], or NaN (without a warning) when mask selects nothing."""
    if not np.any(mask):
        return np.nan
    return np.mean(values[mask])


def _format_mid_line_axes(ax, with_ooc):
    """Label and scale the 3-row mid-line figure (column 0 = BM, column 1 = OoC if present)."""
    ax[0, 0].set_ylabel('z mag (nm)')
    ax[1, 0].set_yscale('log')
    ax[1, 0].set_ylim(0.01, 20)
    ax[1, 0].set_ylabel('velocity gain')
    ax[2, 0].set_ylabel('phase (rad)')
    ax[2, 0].set_xlabel('location (mm)')
    if with_ooc:
        ax[0, 0].set_title('BM')
        ax[0, 1].set_title('OoC')
        ax[1, 1].set_yscale('log')
        ax[2, 1].set_xlabel('location (mm)')


def _plot_frequency_place_map(box_model_prefix, max_locations, freqs):
    """Plot model peak locations against the Muller 2005 / Nankali 2022 maps.

    Saves plots/<prefix>_map.pdf and shows the figure.
    """
    length_mm = 5.8  # x-extent (mm) used to scale both reference maps below

    nankali_df = pd.read_csv('csv/nankali_2022_map.csv', index_col=['x (%)']).sort_index()
    # NOTE(review): 'x (%)' is multiplied by 5.8 directly; if the file stores
    # percentages (0-100) rather than fractions this should be / 100 * 5.8 - confirm.
    nankali_x = nankali_df.index * length_mm
    nankali_freq = nankali_df.loc[:, 'frequency (kHz)'] * 1000  # kHz -> Hz

    muller_x = np.arange(0, length_mm, 0.1)
    muller_freq = 10**((156.5 - muller_x / length_mm * 100) / 82.5) * 1000

    fig, ax = plt.subplots(1, sharex='col', sharey='row')
    ax.plot(muller_x, muller_freq, label='muller 2005 neural data')
    ax.plot(muller_x, muller_freq / np.sqrt(2), label='shifted muller 2005')
    ax.plot(nankali_x, nankali_freq, marker='^', label='nankali 2022 bm data', linestyle='None')
    ax.plot(max_locations, freqs, label='model results', marker='s')
    ax.legend()
    ax.set_yscale('log')
    ax.set_xlabel('location (mm)')
    ax.set_ylabel('frequency (kHz)')
    # Data are plotted in Hz; tick labels are written in kHz to match the axis label.
    ax.set_yticks([4000, 5000, 10000, 30000, 50000])
    ax.set_yticklabels(['4', '5', '10', '30', '50'])
    fig.suptitle(box_model_prefix)
    fig.savefig('plots/' + box_model_prefix + '_map.pdf', transparent=True)
    plt.show()


def midline_and_POI_box_model(box_model_prefix, x_coor, *, bm_half_width=0.01, ooc_half_width=0.03):
    """Post-process COMSOL mid-line results of a box model into CSVs and plots.

    Reads comsol_raw/<prefix>_mid_line.csv (BM) and, if it exists,
    comsol_raw/<prefix>_mid_line_OoC.csv (OoC). Frequencies are parsed as
    integers from the 'freq=' suffix of the BM column headers. For each
    frequency it:

    * converts the z amplitude from mm to nm, unwraps the phase along x and
      computes k = -d(phase)/dx, scaled by 1000 (rad/m for x in mm);
    * records the x location of the BM amplitude maximum;
    * averages BM (and OoC) quantities over a window around x_coor.

    Writes csv/<prefix>_bm_results.csv, csv/<prefix>_max_location_results.csv,
    csv/<prefix>_ooc_results.csv (only with OoC data),
    csv/<prefix>_POI_bm_and_ooc_results.csv, plots/<prefix>_mid_line_results.pdf
    and plots/<prefix>_map.pdf (which also reads csv/nankali_2022_map.csv).
    Both figures are shown with plt.show(), and the BM POI magnitudes are printed.

    Args:
        box_model_prefix: file-name prefix of the model run.
        x_coor: x position of the point of interest, in mid-line x units.
        bm_half_width: half-width of the BM averaging window; samples with
            x_coor - w < x <= x_coor + w are averaged.
        ooc_half_width: the same, for the OoC mid-line.
    """
    csv_box_model_mid_line = 'comsol_raw/' + box_model_prefix + '_mid_line.csv'
    csv_box_model_mid_line_OoC = 'comsol_raw/' + box_model_prefix + '_mid_line_OoC.csv'
    csv_poi_results = 'csv/' + box_model_prefix + '_POI_bm_and_ooc_results.csv'
    csv_bm_results = 'csv/' + box_model_prefix + '_bm_results.csv'
    csv_freq_map_results = 'csv/' + box_model_prefix + '_max_location_results.csv'
    csv_ooc_results = 'csv/' + box_model_prefix + '_ooc_results.csv'

    raw_bm_df = _read_mid_line_csv(csv_box_model_mid_line)
    freqs = sorted({int(col.split('freq=')[1]) for col in raw_bm_df.columns.get_level_values(0)})
    X = np.asarray(raw_bm_df.index.get_level_values(0))  # sorted by _read_mid_line_csv
    new_bm_df = pd.DataFrame(
        index=X,
        columns=pd.MultiIndex.from_product([freqs, ['bm_z_mag_nm', 'bm_z_phase', 'bm_z_k_rad_m']]),
        dtype=float,
    )
    bm_mask = _window_mask(X, x_coor, bm_half_width)

    ooc_mid_line_data = os.path.exists(csv_box_model_mid_line_OoC)
    if ooc_mid_line_data:
        raw_ooc_df = _read_mid_line_csv(csv_box_model_mid_line_OoC)
        X_ooc = np.asarray(raw_ooc_df.index.get_level_values(0))
        new_ooc_df = pd.DataFrame(
            index=X_ooc,
            columns=pd.MultiIndex.from_product([freqs, ['ooc_z_mag_nm', 'ooc_z_phase']]),
            dtype=float,
        )
        ooc_mask = _window_mask(X_ooc, x_coor, ooc_half_width)

    n_freqs = len(freqs)
    max_locations = np.zeros(n_freqs)
    k_atPOI = np.zeros(n_freqs)
    mag_atPOI = np.zeros(n_freqs)
    phase_atPOI = np.zeros(n_freqs)
    gain_atPOI = np.zeros(n_freqs)
    # Stay NaN when there is no OoC file, so those POI columns export empty.
    mag_ooc_atPOI = np.full(n_freqs, np.nan)
    phase_ooc_atPOI = np.full(n_freqs, np.nan)
    gain_ooc_atPOI = np.full(n_freqs, np.nan)

    if ooc_mid_line_data:
        fig, ax = plt.subplots(3, 2, sharex='all', sharey='row', figsize=(10, 8))
    else:
        # squeeze=False keeps ax 2-D so the BM column is always ax[:, 0].
        fig, ax = plt.subplots(3, 1, sharex='col', sharey='row', squeeze=False)

    for fi, freqi in enumerate(freqs):
        ########################################
        # BM z_mag, phase, gain and wavenumber
        mag = _column(raw_bm_df, 'solid.uAmpZ (mm)', freqi) * 10**6  # mm -> nm
        phase = _unwrap_phase_lag(_column(raw_bm_df, 'solid.uPhaseZ (rad)', freqi))
        gain = _column(raw_bm_df, 'solid.uAmp_tZ/u0 (1)', freqi)
        dph_dx = _central_wavenumber(phase, X)  # defined on X[1:-1]

        max_locations[fi] = X[np.argmax(mag)]

        k_col = np.full(len(X), np.nan)  # end points have no central difference
        k_col[1:-1] = dph_dx * 1000
        new_bm_df.loc[:, (freqi, 'bm_z_mag_nm')] = mag
        new_bm_df.loc[:, (freqi, 'bm_z_phase')] = phase
        new_bm_df.loc[:, (freqi, 'bm_z_k_rad_m')] = k_col

        ax[0, 0].plot(X, mag, label=str(freqi / 1000) + ' kHz')
        ax[1, 0].plot(X, gain)
        ax[2, 0].plot(X, phase)

        #### POI BM motion ####
        # dph_dx lives on X[1:-1], so it needs the matching slice of the mask.
        k_atPOI[fi] = _masked_mean(dph_dx, bm_mask[1:-1]) * 1000
        mag_atPOI[fi] = _masked_mean(mag, bm_mask)
        phase_atPOI[fi] = _masked_mean(phase, bm_mask)
        gain_atPOI[fi] = _masked_mean(gain, bm_mask)

        ########################################
        # OoC z_mag, phase and gain
        if ooc_mid_line_data:
            ooc_mag = _column(raw_ooc_df, 'solid.uAmpZ (mm)', freqi) * 10**6  # mm -> nm
            ooc_phase = _unwrap_phase_lag(_column(raw_ooc_df, 'solid.uPhaseZ (rad)', freqi))
            # NOTE(review): the OoC phase is referenced to sample 1, not sample 0 - confirm intent.
            ooc_phase = ooc_phase - ooc_phase[1]
            ooc_gain = _column(raw_ooc_df, 'solid.uAmp_tZ/u0 (1)', freqi)

            new_ooc_df.loc[:, (freqi, 'ooc_z_mag_nm')] = ooc_mag
            new_ooc_df.loc[:, (freqi, 'ooc_z_phase')] = ooc_phase

            ax[0, 1].plot(X_ooc, ooc_mag)
            ax[1, 1].plot(X_ooc, ooc_gain)
            ax[2, 1].plot(X_ooc, ooc_phase)

            #### POI OoC motion ####
            mag_ooc_atPOI[fi] = _masked_mean(ooc_mag, ooc_mask)
            phase_ooc_atPOI[fi] = _masked_mean(ooc_phase, ooc_mask)
            gain_ooc_atPOI[fi] = _masked_mean(ooc_gain, ooc_mask)

    _format_mid_line_axes(ax, ooc_mid_line_data)

    print(mag_atPOI)

    #### export to csv ###
    new_bm_df.sort_index(inplace=True)
    new_bm_df.to_csv(csv_bm_results)
    pd.DataFrame({'max_location': max_locations}, index=freqs).to_csv(csv_freq_map_results)
    if ooc_mid_line_data:
        new_ooc_df.sort_index(inplace=True)
        new_ooc_df.to_csv(csv_ooc_results)

    #### save figure ###
    fig.suptitle(box_model_prefix)
    fig.tight_layout()
    fig.savefig('plots/' + box_model_prefix + '_mid_line_results.pdf', transparent=True)
    plt.show()

    ######################## Frequency-placement map #############################
    _plot_frequency_place_map(box_model_prefix, max_locations, freqs)

    ######################## export POI results #############################
    poi_df = pd.DataFrame(
        {
            'bm_z_mag_nm': mag_atPOI,
            'bm_z_k_rad_m': k_atPOI,
            'bm_z_phase': phase_atPOI,
            'bm_gain': gain_atPOI,
            'ooc_z_mag_nm': mag_ooc_atPOI,
            'ooc_z_phase': phase_ooc_atPOI,
            'ooc_gain': gain_ooc_atPOI,
            # Legacy duplicate of bm_z_k_rad_m: earlier versions wrote k only here.
            'bm_k': k_atPOI,
        },
        index=pd.Index(freqs, name='freq'),
    )
    poi_df.sort_index(inplace=True)
    poi_df.to_csv(csv_poi_results)